# ------------------------------------------------------------------------------
# Generates model subscores in the validation and test sets
# From: stress_test_medicare repo <https://gitlab.com/labsysmed/zolab-projects/stress_test_medicare/-/blob/master/code/03_analysis/01_build-models/02_tune_gbm.R>
# Updates author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bsub -q big -R "rusage[mem=25000]" bash 05_predict_ensemble_components.sh {dataset} {split} {restriction}
# ------------------------------------------------------------------------------

# Seeding ----------------------------------------------------------------------
# Fix the RNG seed up front so any randomness in this run is reproducible.
set.seed(1)

# Libraries --------------------------------------------------------------------
message("Loading libraries...")
library(yaml)
library(data.table)
library(tidyverse)
library(glue)
library(Matrix)
library(glmnet)
library(xgboost)
library(optparse)
library(testit) # assert
library(here) # here() relative filepaths

u <- modules::use(here("lib", "util.R"))

# Command Line Args ------------------------------------------------------------
# All three options are required character strings:
#   --dataset      used to build the "{dataset}_cohort.rds" and
#                  "{dataset}_features.rds" file names
#   --split        subdirectory name under the cohort / feature / model /
#                  subscore directories
#   --restriction  feature-subset label; also the models/subscores subdirectory
arg_config <- list(
  make_option("--dataset", type = "character"),
  make_option("--split", type = "character"),
  make_option("--restriction", type = "character")
)
arg_parser <- OptionParser(option_list = arg_config)
arg_list <- parse_args(arg_parser)

# An option that was not passed comes back as NULL, which would only surface
# later as an opaque failure (zero-length file paths, "argument is of length
# zero"). Fail fast here and name the missing option(s) instead.
required_args <- c("dataset", "split", "restriction")
is_supplied <- vapply(
  required_args,
  function(arg_name) {
    value <- arg_list[[arg_name]]
    is.character(value) && length(value) == 1L && !is.na(value) && nzchar(value)
  },
  logical(1)
)
if (!all(is_supplied)) {
  stop(
    "Missing required option(s): ",
    paste0("--", required_args[!is_supplied], collapse = ", "),
    call. = FALSE
  )
}

# Kept for backward compatibility; note this name shadows base::split().
split <- arg_list$split

# Helpers ----------------------------------------------------------------------
# Build the score column name(s) for the given model file name(s).
#
# Args:
#   model_file: character vector of model file names, e.g. "gbm__foo.rds".
# Returns:
#   A glue (character) vector, one element per file:
#   "p__" + file name without its trailing ".rds" + "__{restriction}".
#   The restriction suffix is omitted when `arg_list$restriction` is "all".
#   Reads `arg_list` from the enclosing script scope.
get_score_name <- function(model_file) {
  restriction_lab <- if (arg_list$restriction == "all") {
    ""
  } else {
    glue("__{arg_list$restriction}")
  }
  # The pattern is a regex: escape the dot and anchor at the end so only the
  # file extension is stripped. A bare ".rds" lets "." match any character and
  # removes the first hit, which may sit mid-name (e.g. "ards" in "cards").
  model_stem <- str_remove(model_file, "\\.rds$")
  glue("p__{model_stem}{restriction_lab}")
}

# Build the "z__"-prefixed score column name(s) for the given model file
# name(s).
#
# Args:
#   model_file: character vector of model file names, e.g. "lasso__foo.rds".
# Returns:
#   A glue (character) vector, one element per file:
#   "z__" + file name without its trailing ".rds" + "__{restriction}".
#   The restriction suffix is omitted when `arg_list$restriction` is "all".
#   Reads `arg_list` from the enclosing script scope.
get_score_name_log <- function(model_file) {
  restriction_lab <- if (arg_list$restriction == "all") {
    ""
  } else {
    glue("__{arg_list$restriction}")
  }
  # The pattern is a regex: escape the dot and anchor at the end so only the
  # file extension is stripped. A bare ".rds" lets "." match any character and
  # removes the first hit, which may sit mid-name (e.g. "ards" in "cards").
  model_stem <- str_remove(model_file, "\\.rds$")
  glue("z__{model_stem}{restriction_lab}")
}

# Directories ------------------------------------------------------------------
message("Establishing directories...")
paths <- read_yaml(here("lib", "filepaths.yml"))

# Shorthand for the path components reused below.
modeling_root <- paths$modeling$dir
split_name <- arg_list$split
restriction_name <- arg_list$restriction

# Inputs: cohort and feature files are organised by split only.
cohort_dir <- file.path(modeling_root, "cohorts", split_name)
features_dir <- file.path(paths$features$dir, split_name)

# Models (input) and subscores (output) are organised by split, then by
# restriction.
models_dir <- file.path(modeling_root, "models", split_name, restriction_name)
subscores_dir <- file.path(
  modeling_root, "subscores", split_name, restriction_name
)

# The output directory must already exist; this script does not create it.
assert("Subscores directory exists", dir.exists(subscores_dir))

# Load Data --------------------------------------------------------------------
message("Loading data...")
message("Dataset: ", arg_list$dataset)

# Identifier columns; prediction columns are appended to this table later.
ids <- readRDS(file.path(cohort_dir, glue("{arg_list$dataset}_cohort.rds"))) %>%
  select(ptid, ed_enc_id) %>%
  setDT()

# Feature matrix, converted by the project helper u$sparsify().
x <- readRDS(file.path(features_dir, glue("{arg_list$dataset}_features.rds")))
x <- u$sparsify(x)

# Restrict the feature columns according to --restriction.
# Restrictions that keep every column whose name matches a single pattern:
restriction_patterns <- c(
  justcc = "ed_enc_t0d", # chief complaint features only
  dem = "dem_",
  dia = "dia_",
  lab = "lab_",
  lvs = "lvs_",
  med = "med_",
  prc = "prc_"
)
feature_names <- colnames(x)
restriction <- arg_list$restriction

keep_mask <- if (restriction == "dropcc") {
  # drop chief complaint features
  !grepl("ed_enc_t0d", feature_names)
} else if (restriction == "enc") {
  # NOTE(review): "enc_" also matches the "ed_enc_t0d" columns unless their
  # names contain "_cc_" -- confirm this is intended.
  grepl("enc_", feature_names) & !grepl("_cc_", feature_names)
} else if (restriction == "represent") {
  representative_vars <- readRDS(paths$analysis$representative_vars)
  feature_names %in% representative_vars
} else if (restriction %in% names(restriction_patterns)) {
  grepl(restriction_patterns[[restriction]], feature_names)
} else {
  # Any other value (e.g. "all") applies no filter: every feature is kept.
  NULL
}

if (!is.null(keep_mask)) {
  keep_feats <- which(keep_mask)
  # drop = FALSE: without it, selecting exactly one column collapses the
  # matrix to a plain vector, which the predict() calls below cannot use.
  x <- x[, keep_feats, drop = FALSE]
}

# Locate Models ----------------------------------------------------------------
message("Locating model files...")
# File names (not full paths) of everything in the models directory.
model_files <- list.files(models_dir)
# Split by model family using the marker embedded in each file name.
gbm_files <- grep("gbm__", model_files, value = TRUE)
lasso_files <- grep("lasso__", model_files, value = TRUE)

# Predict GBM ------------------------------------------------------------------
message("Predicting with GBMs...")
gbm_models <- lapply(file.path(models_dir, gbm_files), xgb.load)

# One "p__" and one "z__" column name per GBM file, in file order.
gbm_p_cols <- get_score_name(gbm_files)
gbm_z_cols <- get_score_name_log(gbm_files)

# First pass: predictions with outputmargin = FALSE, added to `ids` by
# reference under the "p__" names.
for (model_idx in seq_along(gbm_models)) {
  ids[
    ,
    (gbm_p_cols[[model_idx]]) :=
      predict(gbm_models[[model_idx]], x, outputmargin = FALSE)
  ]
}

# Second pass: predictions with outputmargin = TRUE under the "z__" names.
# Kept as a separate loop so all "p__" columns precede all "z__" columns.
for (model_idx in seq_along(gbm_models)) {
  ids[
    ,
    (gbm_z_cols[[model_idx]]) :=
      predict(gbm_models[[model_idx]], x, outputmargin = TRUE)
  ]
}

# Release the loaded boosters before the next model family is loaded.
rm(gbm_models)

# Predict Lasso ----------------------------------------------------------------
message("Predicting with lasso...")
lasso_models <- map(
  file.path(models_dir, lasso_files),
  readRDS
)

# predict() with s = "lambda.min" on a cv.glmnet fit returns a one-column
# matrix. as.vector() strips the dim/dimnames so each new column of `ids` is a
# plain numeric vector rather than a matrix stored inside the data.table.

# type = "response" predictions, stored under the "p__" column names.
walk2(
  get_score_name(lasso_files), lasso_models,
  ~ ids[
    ,
    (.x) := as.vector(predict(.y, x, s = "lambda.min", type = "response"))
  ]
)
# type = "link" predictions, stored under the "z__" column names.
walk2(
  get_score_name_log(lasso_files), lasso_models,
  ~ ids[
    ,
    (.x) := as.vector(predict(.y, x, s = "lambda.min", type = "link"))
  ]
)
rm(lasso_models)

# Save Data --------------------------------------------------------------------
message("Saving results...")
# Output: the id columns plus every score column added above, written to
# "<subscores_dir>/subscores_{dataset}_set.rds".
subscores_file <- glue("subscores_{arg_list$dataset}_set.rds")
subscores_path <- file.path(subscores_dir, subscores_file)
write_rds(ids, subscores_path)

# Done -------------------------------------------------------------------------
message("Done.")
